#!/usr/bin/env python3
"""
steer_eval_soft.py
"""

from __future__ import annotations
import argparse, json, math, warnings, re, inspect, os
from functools import lru_cache
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from tqdm.auto import tqdm

# ---------------- Canonical tags -----------------
# The four canonical reasoning-step tags, in their fixed index order.
FOUR_TAGS = [
    "final_answer",
    "setup_and_retrieval",
    "analysis_and_computation",
    "uncertainty_and_verification",
]
# Reverse lookup: tag name -> position in FOUR_TAGS.
TAG2ID = dict(zip(FOUR_TAGS, range(len(FOUR_TAGS))))

# ==================== Determinism ====================
def set_global_determinism(seed: int, strict: bool = False):
    """Seed Python, NumPy and torch RNGs and configure torch's determinism mode.

    Args:
        seed: value passed to every RNG seeder (and to all CUDA devices when present).
        strict: when True, set CUBLAS_WORKSPACE_CONFIG (only if unset), enable
            deterministic algorithms in warn-only mode and pin cuDNN to deterministic,
            non-benchmark kernels. When False, deterministic algorithms are disabled.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not strict:
        torch.use_deterministic_algorithms(False)
        return

    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
    torch.use_deterministic_algorithms(True, warn_only=True)
    try:
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = False
        cudnn.deterministic = True
    except Exception:
        # Best effort: cuDNN knobs are optional.
        pass

# ==================== Distributed ====================
def _dist_is_enabled() -> bool:
    """Return True only if torch.distributed is importable, available and initialised."""
    try:
        import torch.distributed as dist
        if not dist.is_available():
            return False
        return dist.is_initialized()
    except Exception:
        return False

def _dist_init(backend: str = "nccl"):
    """Bind to the LOCAL_RANK GPU (if CUDA exists) and initialise the default process group.

    Does nothing further if torch.distributed is unavailable or already initialised.
    If initialising with `backend` raises, retries once with "gloo".
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

    import torch.distributed as dist
    if not dist.is_available() or dist.is_initialized():
        return
    try:
        dist.init_process_group(backend=backend)
    except Exception:
        dist.init_process_group(backend="gloo")

def _dist_rank() -> int:
    """Global rank of this process; 0 when not running under torch.distributed."""
    if _dist_is_enabled():
        import torch.distributed as dist
        return dist.get_rank()
    return 0

def _dist_world() -> int:
    """World size of the default process group; 1 when not running distributed."""
    if _dist_is_enabled():
        import torch.distributed as dist
        return dist.get_world_size()
    return 1

def _dist_local_rank() -> int:
    """Node-local rank taken from the LOCAL_RANK environment variable (default 0)."""
    local = os.getenv("LOCAL_RANK", "0")
    return int(local)

def _only_rank0() -> bool:
    """True on the rank-0 process (always True in single-process runs)."""
    rank = _dist_rank()
    return rank == 0

def _shard_list(xs, rank: int, world: int):
    """Return the contiguous slice of `xs` owned by `rank` among `world` ranks.

    Every rank gets ceil(len(xs) / world) items except the trailing ones, which may
    get fewer or none. With world <= 1 or an empty sequence, `xs` itself is returned.
    """
    total = len(xs)
    if total == 0 or world <= 1:
        return xs
    chunk = math.ceil(total / world)
    lo = rank * chunk
    hi = min(total, lo + chunk)
    return xs[lo:hi]

# ==================== Preproc ====================
def load_preproc(model_npz_path: str):
    """Rebuild the fitted StandardScaler (and optional PCA) stored in a model NPZ.

    Expected keys: "prep_mean" and "prep_scale" (scaler). PCA is rebuilt only when
    "prep_pca_components" exists and is non-empty; it also reads "prep_pca_mean" and,
    when present, "prep_pca_explained_variance", "prep_pca_explained_variance_ratio"
    and "prep_pca_singular_values". Absent ones default to ones / uniform ratios.

    Returns:
        (scaler, pca) where pca is None if no PCA components were stored.

    NOTE: allow_pickle=True can execute code embedded in a malicious NPZ; only load
    files you produced yourself.
    """
    from sklearn.preprocessing import StandardScaler

    # Context manager closes the NPZ's underlying file; z[key] returns an in-memory copy.
    with np.load(model_npz_path, allow_pickle=True) as z:
        present = set(z.files)

        def _opt(key, default):
            return z[key] if key in present else default

        prep_mean = z["prep_mean"]
        prep_scale = z["prep_scale"]
        comps = _opt("prep_pca_components", None)
        pca_state = None
        if comps is not None and comps.size > 0:
            k = int(comps.shape[0])
            pca_state = dict(
                components=comps,
                mean=z["prep_pca_mean"],
                explained_variance=_opt("prep_pca_explained_variance", np.ones(k)),
                explained_variance_ratio=_opt("prep_pca_explained_variance_ratio", np.ones(k) / k),
                singular_values=_opt("prep_pca_singular_values", np.ones(k)),
            )

    scaler = StandardScaler()
    scaler.mean_ = prep_mean
    scaler.scale_ = prep_scale
    scaler.var_ = scaler.scale_ ** 2
    scaler.n_features_in_ = scaler.mean_.shape[0]

    pca = None
    if pca_state is not None:
        from sklearn.decomposition import PCA
        k = int(pca_state["components"].shape[0])
        pca = PCA(n_components=k, svd_solver="full")
        pca.components_ = pca_state["components"]
        pca.mean_ = pca_state["mean"]
        pca.n_features_in_ = int(pca_state["mean"].shape[0])
        pca.explained_variance_ = pca_state["explained_variance"]
        pca.explained_variance_ratio_ = pca_state["explained_variance_ratio"]
        pca.singular_values_ = pca_state["singular_values"]
    return scaler, pca

def invert_preproc_step(z_row: np.ndarray, scaler, pca) -> np.ndarray:
    """Map one preprocessed row back to the model's hidden space.

    Applies pca.inverse_transform (when a PCA is given) and then undoes the scaler
    standardisation: x_hat = x_std * scaler.scale_ + scaler.mean_.
    """
    if pca is None:
        standardized = z_row
    else:
        standardized = pca.inverse_transform(z_row[None, :])[0]
    unscaled = standardized * scaler.scale_
    return unscaled + scaler.mean_

# -------- Nemotron / Qwen chat helpers --------
def nemotron_build_prompt(tokenizer, system_text: str, user_text: str) -> str:
    """Render a system + user conversation via the tokenizer's chat template (text, with generation prompt)."""
    conversation = [
        dict(role="system", content=system_text),
        dict(role="user", content=user_text),
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def qwen_build_prompt(tokenizer, user_text: str, enable_thinking: bool) -> str:
    """Render a single user turn via the chat template, passing `enable_thinking` when supported.

    Tokenizers whose apply_chat_template rejects the extra keyword (TypeError) are
    retried without it.
    """
    conversation = [{"role": "user", "content": user_text}]
    common = dict(tokenize=False, add_generation_prompt=True)
    try:
        return tokenizer.apply_chat_template(conversation, enable_thinking=enable_thinking, **common)
    except TypeError:
        return tokenizer.apply_chat_template(conversation, **common)

# --------------- Model loader & utilities --------------
def _resolve_dtype(name: str):
    """Translate "float16" / "bfloat16" / "float32" to the torch dtype (KeyError otherwise)."""
    table = dict(float16=torch.float16, bfloat16=torch.bfloat16, float32=torch.float32)
    return table[name]

def _pick_device(device_arg: str) -> str:
    """Choose the device string for model placement.

    Under torch.distributed: "cuda:<LOCAL_RANK>" (or "cpu" without CUDA), ignoring
    `device_arg`. Otherwise "auto" picks "cuda"/"cpu" and any other value is returned as-is.
    """
    if not _dist_is_enabled():
        if device_arg != "auto":
            return device_arg
        return "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        return f"cuda:{_dist_local_rank()}"
    return "cpu"

def _model_device(mdl: torch.nn.Module):
    """Device of the model's first parameter (StopIteration if it has none)."""
    first_param = next(iter(mdl.parameters()))
    return first_param.device

@lru_cache(maxsize=8)
def load_tok_mdl(model_name: str, tokenizer_name: Optional[str], device: str, dtype: str):
    """Load (and memoise on the argument tuple) a tokenizer and an eval-mode causal LM.

    The tokenizer pads on the left and falls back to EOS as the pad token. Remote code
    is trusted only when the model name contains "qwen". fp16 weights are upcast to
    fp32 when the resolved device is "cpu".

    Returns:
        (tokenizer, model) with the model moved to the resolved device.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    target_device = _pick_device(device)
    torch_dtype = _resolve_dtype(dtype)
    trust_remote = "qwen" in model_name.lower()

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name or model_name,
        trust_remote_code=trust_remote,
    )
    # Left padding keeps every prompt ending at the same column, which the runner
    # relies on when it slices generated tokens starting at the padded input length.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): `dtype=` is the newer transformers spelling; older releases expect
    # `torch_dtype=` — confirm against the pinned transformers version.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote,
    )
    if target_device == "cpu" and torch_dtype is torch.float16:
        model = model.to(dtype=torch.float32)
    model = model.to(target_device).eval()
    return tokenizer, model

def _supports_top_k(mdl) -> bool:
    """True if the model's generation_config has `top_k` or `generate` names it as a parameter."""
    try:
        cfg = getattr(mdl, "generation_config", None)
        if cfg is not None and hasattr(cfg, "top_k"):
            return True
    except Exception:
        pass
    try:
        return "top_k" in inspect.signature(mdl.generate).parameters
    except Exception:
        return False

def _supports_min_p(mdl) -> bool:
    """True if the model's generation_config has `min_p` or `generate` names it as a parameter."""
    try:
        cfg = getattr(mdl, "generation_config", None)
        if cfg is not None and hasattr(cfg, "min_p"):
            return True
    except Exception:
        pass
    try:
        return "min_p" in inspect.signature(mdl.generate).parameters
    except Exception:
        return False

# -------- Hidden width inference + validation --------
def _infer_hidden_width(mdl: torch.nn.Module) -> int:
    """Best-effort hidden size of a model.

    Tries, in order: a truthy config.hidden_size; the input embedding's embedding_dim
    or its 2-D weight's second dimension; the second dimension of the first 2-D
    parameter. Raises RuntimeError when none of these is available.
    """
    width = getattr(getattr(mdl, "config", None), "hidden_size", None)
    if width:
        return int(width)

    try:
        emb = mdl.get_input_embeddings()
        if emb is not None:
            dim = getattr(emb, "embedding_dim", None)
            if dim:
                return int(dim)
            weight = getattr(emb, "weight", None)
            if weight is not None and weight.ndim == 2:
                return int(weight.shape[1])
    except Exception:
        pass

    fallback = next((int(p.shape[1]) for p in mdl.parameters() if p.ndim == 2), None)
    if fallback is not None:
        return fallback
    raise RuntimeError("Could not infer model hidden width.")

def _validate_vec_dim(vec_hidden: np.ndarray, mdl: torch.nn.Module, where: str):
    """Check that a steering vector is 1-D and matches the model's hidden width.

    None is accepted silently. Raises ValueError (prefixed with `[where]`) for a
    non-1-D vector or a width mismatch.
    """
    if vec_hidden is None:
        return
    if vec_hidden.ndim != 1:
        raise ValueError(f"[{where}] Steering vector must be 1-D, got shape={tuple(vec_hidden.shape)}")
    want = _infer_hidden_width(mdl)
    got = int(vec_hidden.shape[0])
    if want == got:
        return
    raise ValueError(
        f"[{where}] Steering vector width mismatch: vec={got}, model_hidden={want}. "
        "Ensure you invert PCA+scaler with the SAME model_npz used to build the stats."
    )

# --------------- Schedules (linear-only) --------------
def make_schedule(kind: str, layers: List[int], alpha: float,
                  start_frac: float = 0.2) -> Dict[int, float]:
    """Build a per-layer steering strength schedule (linear ramp only).

    The shallowest listed layer gets `start_frac * alpha`, the deepest gets `alpha`,
    and layers in between are interpolated linearly by layer index. The input order
    of `layers` does not matter. If all listed layers are the same layer, each gets
    `alpha` (matching the single-layer case).

    Args:
        kind: schedule name; only "linear" is implemented and any other value is
            treated as "linear".
        layers: layer indices to steer.
        alpha: strength at the deepest layer.
        start_frac: fraction of alpha used at the shallowest layer (default 0.2).

    Returns:
        {layer_index: strength}; empty when `layers` is empty.
    """
    if not layers:
        return {}
    # Use the true extremes so unsorted input still yields a ramp within [a0, a1].
    lo, hi = min(layers), max(layers)
    if lo == hi:
        return {L: alpha for L in layers}
    a0, a1 = start_frac * alpha, alpha
    span = hi - lo
    return {L: (1 - (L - lo) / span) * a0 + ((L - lo) / span) * a1 for L in layers}

# --------------- Step-aware gate (batched) --------------
class BatchStepGate:
    """Per-row flags deciding whether steering fires on the next forward pass.

    `BatchNewlineWatcher` rewrites the flags (True or False) after every decoding step;
    hooks from `_make_batched_gated_add_hook` only read them, so every hooked layer
    sees the same decision during a forward pass.
    """

    def __init__(self, batch_size: int):
        # One flag per batch row; all start disarmed (no steering during prefill).
        self.apply_now = [False] * batch_size


class BatchNewlineWatcher:
    """Logits processor that arms a row's gate iff its sequence currently ends with a blank line.

    `generate` calls it once per step with the sequence so far, after the forward
    pass; the flag it writes therefore controls steering of the *next* token, i.e.
    the first token after a double newline. Scores are returned unchanged.
    """

    _DECODE_TAIL = 8  # trailing tokens decoded by the text-level fallback check

    def __init__(self, tokenizer, gate: BatchStepGate):
        self.tok = tokenizer
        self.gate = gate
        self.nn_ids = self.tok("\n\n", add_special_tokens=False).input_ids
        self.n_ids  = self.tok("\n",   add_special_tokens=False).input_ids
        # Only this many trailing ids are ever inspected, so avoid copying the
        # whole (growing) sequence to a Python list every step.
        self._tail_len = max(self._DECODE_TAIL, len(self.nn_ids), 2 * len(self.n_ids))

    def _ends_with(self, ids, pat):
        L = len(pat)
        return L > 0 and len(ids) >= L and ids[-L:] == pat

    def _should_arm(self, ids) -> bool:
        """True when `ids` ends with a double newline (token-level checks, then decoded text)."""
        if len(self.nn_ids) == 1 and self._ends_with(ids, self.nn_ids):
            return True
        if len(self.n_ids) == 1 and self._ends_with(ids, self.n_ids + self.n_ids):
            return True
        tail = self.tok.decode(ids[-self._DECODE_TAIL:], skip_special_tokens=False)
        # Do not rstrip(): stripping would remove the very newlines being tested for.
        return tail.endswith("\n\n")

    def __call__(self, input_ids, scores):
        rows = min(int(input_ids.shape[0]), len(self.gate.apply_now))
        for b in range(rows):
            ids = input_ids[b, -self._tail_len:].tolist()
            # Write both True and False: the watcher alone owns arming/disarming.
            self.gate.apply_now[b] = self._should_arm(ids)
        return scores


def _make_batched_gated_add_hook(vec: torch.Tensor, a: float, gate: BatchStepGate):
    """Build a forward hook adding `a * vec` to the last position of rows whose gate is armed.

    The hook reads `gate.apply_now` without clearing it, so when several layers share
    one gate they are all steered on the same step. Handles a bare tensor output, a
    tuple whose first element is a tensor, or an object with `last_hidden_state`;
    the hidden-state tensor is modified in place and returned.
    """
    def _apply(h: torch.Tensor) -> torch.Tensor:
        v = (a * vec).to(h.device, dtype=h.dtype)
        if h.dim() == 2:  # [T, H]: unbatched output, row 0's flag decides
            if gate.apply_now and gate.apply_now[0]:
                h[-1, :] = h[-1, :] + v
            return h

        B = h.shape[0]  # [B, T, H]
        flags = list(gate.apply_now[:B])
        flags += [False] * (B - len(flags))  # rows beyond the gate are never steered
        m = torch.tensor(flags, device=h.device, dtype=h.dtype).unsqueeze(-1)
        h[:, -1, :] = h[:, -1, :] + m * v
        return h

    def hook(_mod, _inp, out):
        if torch.is_tensor(out): return _apply(out)
        if isinstance(out, tuple) and len(out) > 0 and torch.is_tensor(out[0]):
            hs = _apply(out[0]); return (hs, *out[1:])
        if hasattr(out, "last_hidden_state") and torch.is_tensor(out.last_hidden_state):
            out.last_hidden_state = _apply(out.last_hidden_state); return out
        return out
    return hook

# --------------- Runner --------------
def make_gen(device: torch.device, seed: int) -> torch.Generator:
    """Return a new torch.Generator on `device`, seeded with int(seed)."""
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))
    return generator

class LLMRunner:
    """Batched generation with optional activation steering through forward hooks.

    Wraps a cached tokenizer/model pair from `load_tok_mdl`, formats prompts (raw,
    Nemotron chat or Qwen chat, optionally prefixed with a boxed-answer instruction)
    and, when a layer schedule and a steering vector are given, adds the vector to the
    outputs of the scheduled transformer blocks while `generate` runs.
    """

    def __init__(self, model_name: str, tokenizer_name: Optional[str],
                 temperature: float=0.0, top_p: float=1.0, max_new_tokens: int=256,
                 device: str = "auto", dtype: str = "float16", top_k: Optional[int] = None,
                 use_nemotron_chat: bool = False, system_text: str = "detailed thinking on",
                 final_boxed_hint: bool = False, use_qwen_chat: bool = False, qwen_enable_thinking: bool = True,
                 min_p: Optional[float] = None):
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.temperature = float(temperature); self.top_p = float(top_p)
        # top_k <= 0 means "disabled".
        self.top_k = (int(top_k) if top_k is not None and int(top_k) > 0 else None)
        self.max_new_tokens = int(max_new_tokens)
        self.tok, self.mdl = load_tok_mdl(model_name, tokenizer_name, device, dtype)
        self.min_p = (float(min_p) if min_p is not None else None)
        self.use_nemotron_chat = bool(use_nemotron_chat)
        self.system_text = system_text
        self.final_boxed_hint = bool(final_boxed_hint)
        self.use_qwen_chat = bool(use_qwen_chat)
        self.qwen_enable_thinking = bool(qwen_enable_thinking)

    def _uses_chat_template(self) -> bool:
        """True when prompts are rendered through the tokenizer's chat template."""
        return self.use_nemotron_chat or self.use_qwen_chat

    def _format_prompts(self, prompts: List[str]) -> List[str]:
        """Apply the optional boxed-answer instruction and the configured chat template."""
        ANSWER_IN_BOX_PROMPT = (
            "Answer the following question step-by-step.\n"
            "At the very end, output exactly one line formatted as:\n"
            "Final Answer: \\boxed{...}\n"
        )
        outs = []
        for p in prompts:
            user_text = p
            if self.final_boxed_hint:
                user_text = f"{ANSWER_IN_BOX_PROMPT}\n{user_text.rstrip()}"
            if self.use_nemotron_chat:
                outs.append(nemotron_build_prompt(self.tok, self.system_text, user_text))
            elif self.use_qwen_chat:
                outs.append(qwen_build_prompt(self.tok, user_text=user_text, enable_thinking=self.qwen_enable_thinking))
            else:
                outs.append(user_text)
        return outs

    def _build_gen_kwargs(self, batch_inputs, do_sample: bool):
        """Assemble keyword arguments for `generate`, merging in the tokenized batch.

        Sampling parameters (temperature, top_p, top_k, min_p) are passed only when
        `do_sample` is True; top_k/min_p are skipped with a rank-0 warning if the model
        does not appear to support them.
        """
        eos_id = self.tok.eos_token_id
        pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else eos_id
        gen_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            eos_token_id=eos_id,
            pad_token_id=pad_id,
        )
        if do_sample:
            # Sampling knobs are meaningless for greedy decoding and only trigger
            # "set but ignored" warnings there.
            gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = self.top_p
        if do_sample and self.top_k is not None:
            try:
                if _supports_top_k(self.mdl):
                    gen_kwargs["top_k"] = int(self.top_k)
                else:
                    if _only_rank0():
                        warnings.warn(f"Model '{self.model_name}' does not support top_k; ignoring --gen_top_k.")
            except Exception:
                if _only_rank0():
                    warnings.warn(f"Could not set top_k for model '{self.model_name}'; ignoring.")
        if do_sample and self.min_p is not None:
            try:
                if _supports_min_p(self.mdl):
                    gen_kwargs["min_p"] = float(self.min_p)
                else:
                    if _only_rank0():
                        warnings.warn(f"Model '{self.model_name}' does not support min_p; ignoring --min_p.")
            except Exception:
                if _only_rank0():
                    warnings.warn(f"Could not set min_p for model '{self.model_name}'; ignoring.")
        gen_kwargs.update(batch_inputs)
        return gen_kwargs

    def _find_blocks(self):
        """Return the model's list of transformer blocks (`.model.layers` or `.transformer.h`)."""
        blocks = (getattr(getattr(self.mdl, "model", None), "layers", None)
               or getattr(getattr(self.mdl, "transformer", None), "h", None))
        if blocks is None:
            raise RuntimeError("Unsupported model structure for hooking blocks (no .model.layers or .transformer.h).")
        return blocks

    def _register_hooks(self, schedule: Dict[int,float], vec_hidden: np.ndarray,
                        batch_size: int, step_aware: bool):
        """Attach steering forward hooks to the blocks named in `schedule`.

        Non-int or out-of-range layer indices are skipped with a warning. With
        `step_aware`, hooks add the scaled vector only at the last position of rows
        whose BatchStepGate flag is set; otherwise every position of the block output
        is shifted unconditionally.

        Returns:
            (handles, gate): hook handles to remove later, and the gate (None unless
            step_aware). If registration fails part-way, hooks already attached are
            removed before the exception propagates.
        """
        _validate_vec_dim(vec_hidden, self.mdl, where="register_hooks")

        blocks = self._find_blocks()
        n_layers = len(blocks)
        if n_layers <= 0:
            raise RuntimeError("Model has zero transformer blocks?")

        valid = []
        for L, a in schedule.items():
            if not isinstance(L, int):
                warnings.warn(f"[hooks] Layer index {L} is not int; skipping.")
                continue
            if 0 <= L < n_layers:
                valid.append((L, a))
            else:
                warnings.warn(f"[hooks] Layer {L} out of range [0, {n_layers-1}]; skipping.")

        if not valid and _only_rank0():
            warnings.warn("[hooks] No valid layers to hook after range checks.")

        gate = BatchStepGate(batch_size) if step_aware else None
        v_gpu = torch.tensor(vec_hidden, device=_model_device(self.mdl), dtype=next(self.mdl.parameters()).dtype)

        def _make_add(vec_t, alpha):
            """Unconditional hook: add alpha * vec_t to every position of the block output."""
            def _hook(_m, _i, out):
                def _add(h):
                    return h + alpha * vec_t.to(h.device, dtype=h.dtype)
                if torch.is_tensor(out): return _add(out)
                if isinstance(out, tuple) and len(out)>0 and torch.is_tensor(out[0]):
                    hs = _add(out[0]); return (hs, *out[1:])
                if hasattr(out, "last_hidden_state") and torch.is_tensor(out.last_hidden_state):
                    out.last_hidden_state = _add(out.last_hidden_state); return out
                return out
            return _hook

        handles = []
        try:
            for L, a in valid:
                hook_fn = (_make_batched_gated_add_hook(v_gpu, a, gate) if step_aware
                           else _make_add(v_gpu, a))
                handles.append(blocks[L].register_forward_hook(hook_fn))
        except Exception:
            # Never leave a partially-steered model behind.
            for h in handles:
                h.remove()
            raise

        if _only_rank0() and valid:
            hooked_layers = ", ".join(str(L) for (L, _) in valid)
            print(f"[hooks] Registered {len(valid)} hooks on layers: {hooked_layers}")

        return handles, gate

    @torch.inference_mode()
    def generate_batched(self, prompts: List[str], schedule: Optional[Dict[int, float]] = None,
                        vec_hidden: Optional[np.ndarray] = None, step_aware: bool = True,
                        torch_generator: Optional[torch.Generator] = None) -> Tuple[List[str], List[int]]:
        """Generate one completion per prompt in a single left-padded batch.

        Returns:
            (texts, gen_tokens): texts[b] is the decoded completion only (prompt
            excluded, special tokens skipped, cut after the first EOS); gen_tokens[b]
            counts generated tokens up to and including that EOS.

        Steering hooks are attached only when both `schedule` and `vec_hidden` are
        given, and are always removed before returning. Sampling is used only when
        temperature > 0; otherwise decoding is greedy.
        """
        dev = _model_device(self.mdl)
        prompts_fmt = self._format_prompts(prompts)
        # Chat templates already emit their own BOS/role tokens; letting the tokenizer
        # add special tokens again would duplicate BOS for templates that include it.
        batch_inputs = self.tok(
            prompts_fmt, return_tensors="pt", padding=True,
            add_special_tokens=not self._uses_chat_template(),
        ).to(dev)
        T_in_max = int(batch_inputs["input_ids"].size(1))
        eos_id = self.tok.eos_token_id

        # HF rejects temperature == 0 when sampling, so a zero temperature means
        # greedy decoding regardless of top_p / top_k.
        do_sample = self.temperature > 0.0

        handles, gate = [], None
        try:
            if schedule and vec_hidden is not None:
                handles, gate = self._register_hooks(schedule, vec_hidden, batch_size=len(prompts_fmt), step_aware=step_aware)

            gen_kwargs = self._build_gen_kwargs(batch_inputs, do_sample=do_sample)

            # ---- Wire the step-aware gate into generation so it actually fires ----
            if gate is not None and step_aware:
                try:
                    from transformers import LogitsProcessorList
                    lps = gen_kwargs.get("logits_processor", None)
                    if lps is None:
                        lps = LogitsProcessorList()
                    lps.append(BatchNewlineWatcher(self.tok, gate))
                    gen_kwargs["logits_processor"] = lps
                    if _only_rank0():
                        print("[hooks] Step-aware gating: logits processor attached.")
                except Exception as e:
                    if _only_rank0():
                        warnings.warn(f"[hooks] Could not attach BatchNewlineWatcher: {e}. Step-aware will not fire.")

            # NOTE(review): HF `generate` usually has no `generator` parameter, in which
            # case the seeded generator is silently not used — confirm for your version.
            if do_sample and (torch_generator is not None):
                try:
                    sig = inspect.signature(self.mdl.generate)
                    if "generator" in sig.parameters:
                        gen_kwargs["generator"] = torch_generator
                    elif "torch_generator" in sig.parameters:
                        gen_kwargs["torch_generator"] = torch_generator
                except Exception:
                    pass

            out_ids = self.mdl.generate(**gen_kwargs)
            if out_ids.ndim == 1:
                out_ids = out_ids.unsqueeze(0)

            texts, gen_tokens = [], []
            for b in range(out_ids.size(0)):
                seq = out_ids[b].tolist()
                start = T_in_max  # every row's prompt occupies the first T_in_max columns
                end = len(seq)
                if eos_id is not None:
                    for t in range(start, end):
                        if seq[t] == eos_id:
                            end = t + 1
                            break
                # Decode only the completion so prompt text (instructions, option
                # letters, example \boxed{...}) cannot leak into answer extraction.
                texts.append(self.tok.decode(seq[start:end], skip_special_tokens=True))
                gen_tokens.append(end - start)
            return texts, gen_tokens
        finally:
            for h in handles:
                h.remove()

# ---------------- Dataset utils ----------------
def _normalize_text(s: str) -> str:
    """Normalise free text for loose comparison.

    Strips, lower-cases, collapses whitespace runs to one space, then drops every
    character except a-z, 0-9, space and the punctuation . , : ; @ # % / \\ + * = ( ) -
    """
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    # '-' is placed last so it is a literal; the previous "\\-+" parsed as an invalid
    # character range (backslash..plus) and made re raise on every call.
    s = re.sub(r"[^a-z0-9 .,:;@#%/\\+*=()-]", "", s)
    return s

def _extract_number(s: str):
    """Extract the answer number from text as a string, or None if there is none.

    Prefers the number following a GSM8K-style "####" marker; otherwise returns the
    last number in the text. Numbers may carry a sign, decimals, an exponent and
    comma thousands separators (e.g. "1,234.5"); separators are removed from the
    returned string.
    """
    # Grouped form first so "1,234" is one number rather than "1" and "234".
    num = r"[-+]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?(?:e[-+]?\d+)?"
    m = re.search(r"####\s*(" + num + r")", s, flags=re.I)
    if m:
        found = m.group(1)
    else:
        nums = re.findall(num, s, flags=re.I)
        if not nums:
            return None
        found = nums[-1]
    return found.replace(",", "")

# Multiple-choice option labels A-D.
LETTERS = ["A", "B", "C", "D"]
# Finds an answer letter A-D, optionally preceded by "Final Answer:" and optionally
# written as \boxed{X}; group(1) is the letter. NOTE(review): the (?i) flag also makes
# the letter itself case-insensitive, so a standalone lowercase word such as the
# article "a" matches too — callers only search the final line, which limits this.
LETTER_RE = re.compile(r"(?i)(?:Final Answer\s*:\s*)?(?:\\boxed\{|\b)([A-D])(?:\}|\.|\b)")

def _last_nonempty_line(s: str) -> str:
    """Return the last line of `s` that is non-blank after stripping, or "" if none."""
    stripped = (line.strip() for line in reversed(s.splitlines()))
    return next((line for line in stripped if line), "")

def _pick_letter(text: str):
    """First A-D answer letter matched by LETTER_RE in `text`, upper-cased; None if absent or text is empty."""
    match = LETTER_RE.search(text) if text else None
    return None if match is None else match.group(1).upper()

def _extract_answer_ref(ref: str, metric: str, regex_answer: Optional[str]):
    """Extract the gold answer: a number for metric "numeric", else a letter from the last line.

    `regex_answer` is accepted for interface compatibility but not used.
    """
    if metric != "numeric":
        return _pick_letter(_last_nonempty_line(ref))
    return _extract_number(ref)

def _extract_answer_pred(pred: str, metric: str, regex_pred: Optional[str]):
    """Extract the predicted answer: a number for metric "numeric", else a letter from the last line.

    `regex_pred` is accepted for interface compatibility but not used.
    """
    if metric != "numeric":
        return _pick_letter(_last_nonempty_line(pred))
    return _extract_number(pred)

# def _grade(pred: str, ref: str, metric: str="em", regex_answer: Optional[str]=None, regex_pred: Optional[str]=None) -> bool:
#     a = _extract_answer_pred(pred, metric, regex_pred)
#     b = _extract_answer_ref(ref,  metric, regex_answer)
#     if a is None or b is None: return False
#     if metric == "numeric":
#         try: return float(a) == float(b)
#         except Exception: return a == b
#     return a == b

# Helper for the ARC grading path in _grade: extract the contents of a \boxed{...} answer.
def _pick_boxed(s: str):
    """Return the stripped contents of the last brace-balanced \\boxed{...} in `s`, or None.

    Nested braces are kept intact (e.g. \\boxed{\\frac{1}{2}} -> "\\frac{1}{2}"). If the
    last \\boxed{ is never closed, earlier occurrences are tried.
    """
    marker = "\\boxed{"
    pos = s.rfind(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        for k in range(start, len(s)):
            ch = s[k]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return s[start:k].strip()
        pos = s.rfind(marker, 0, pos)  # unbalanced: fall back to an earlier box
    return None

def _grade_arc(pred: str, obj: dict) -> bool:
    """Grade a prediction against an ARC reference dict with "gold_letter" and "text2label".

    Order: a letter matched by LETTER_RE on the last non-empty line; else the boxed
    answer (or that last line) looked up as option text in text2label; else the closest
    option text by difflib similarity (cutoff 0.90). Returns False when nothing matches.
    """
    import difflib

    gold = str(obj.get("gold_letter") or "").strip().upper()
    raw_map = obj.get("text2label") or {}
    if not isinstance(raw_map, dict):
        raw_map = {}
    t2l = {("" if k is None else str(k)).strip().lower(): ("" if v is None else str(v)).strip().upper()
           for k, v in raw_map.items()}

    final_line = _last_nonempty_line(pred)
    m = LETTER_RE.search(final_line)
    if m:
        return m.group(1).upper() == gold

    key = (_pick_boxed(pred) or final_line).strip().lower()
    if key in t2l:
        return t2l[key] == gold

    if t2l:
        cand = difflib.get_close_matches(key, list(t2l), n=1, cutoff=0.90)
        if cand:
            return t2l[cand[0]] == gold
    return False


def _grade(pred: str, ref: str, metric: str="em",
           regex_answer: Optional[str]=None, regex_pred: Optional[str]=None) -> bool:
    """Return True if `pred` is judged correct against `ref`.

    If `ref` is a JSON object containing "gold_letter" and "text2label" (ARC format),
    it is graded by `_grade_arc`. Otherwise the answers are extracted with
    `_extract_answer_pred` / `_extract_answer_ref`; "numeric" compares them as floats
    (falling back to string equality), other metrics compare the extracted strings.
    A missing extraction on either side counts as incorrect.
    """
    # Only JSON parsing is allowed to fail here; errors inside ARC grading must surface
    # instead of silently re-grading the JSON ref with the default path.
    try:
        obj = json.loads(ref)
    except (TypeError, ValueError):
        obj = None
    if isinstance(obj, dict) and "gold_letter" in obj and "text2label" in obj:
        return _grade_arc(pred, obj)

    # ---- Default path — used by GPQA diamond with metric='regex' ----
    a = _extract_answer_pred(pred, metric, regex_pred)
    b = _extract_answer_ref(ref,  metric, regex_answer)
    if a is None or b is None: return False
    if metric == "numeric":
        try: return float(a) == float(b)
        except Exception: return a == b
    return a == b

def _autodetect_keys(r: dict) -> Tuple[str,str]:
    """Guess (prompt_key, answer_key) for a dataset record from common column names.

    Raises KeyError if either side cannot be found.
    """
    prompt_candidates = ("prompt", "question", "input", "query")
    answer_candidates = ("answer", "gold", "target", "gt_answer", "reference")
    pkey = next((name for name in prompt_candidates if name in r), None)
    akey = next((name for name in answer_candidates if name in r), None)
    if pkey and akey:
        return pkey, akey
    raise KeyError("Autodetect failed; set --pt_prompt_key/--pt_answer_key")

def _load_hf_dataset_items(
    ds_name: str,
    ds_config: Optional[str] = None,
    split: str = "test",
    prompt_key: str = "question",
    answer_key: str = "answer",
    n: Optional[int] = None,
    seed: int = 0,
    skip_first: int = 0,
    filter_answer_types: Optional[List[str]] = None,
    filter_difficulties: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
    """Load a Hugging Face dataset split as stripped (prompt, answer) string pairs.

    Optional filters keep rows whose "answer_type" / "difficulty" is in the given list
    (applied only if that column exists). Rows from index `skip_first` on are used; if
    `n` is truthy and smaller than that pool, n rows are drawn after shuffling with
    np.random.default_rng(seed), otherwise the pool is truncated to `n` in order.
    Records missing prompt_key/answer_key fall back to `_autodetect_keys`.
    """
    from datasets import load_dataset

    if ds_config:
        ds = load_dataset(ds_name, ds_config, split=split)
    else:
        ds = load_dataset(ds_name, split=split)

    if filter_answer_types is not None and "answer_type" in ds.column_names:
        allowed_types = {x.strip() for x in filter_answer_types}
        ds = ds.filter(lambda r: r.get("answer_type") in allowed_types)
    if filter_difficulties is not None and "difficulty" in ds.column_names:
        allowed_levels = {x.strip() for x in filter_difficulties}
        ds = ds.filter(lambda r: r.get("difficulty") in allowed_levels)

    total = len(ds)
    indices = list(range(min(skip_first, total), total))
    if n and n < len(indices):
        np.random.default_rng(seed).shuffle(indices)
    indices = indices[:n]

    def _pair(rec):
        if prompt_key in rec and answer_key in rec:
            pk, ak = prompt_key, answer_key
        else:
            pk, ak = _autodetect_keys(rec)
        return str(rec[pk]).strip(), str(rec[ak]).strip()

    return [_pair(ds[i]) for i in indices]

def _load_gpqa_diamond_items(
    split: str = "train",
    n: Optional[int] = 100,
    seed: int = 0,
    skip_first: int = 0,
) -> List[Tuple[str, str]]:
    """Load GPQA-diamond as (prompt, gold_letter) pairs with shuffled answer options.

    Rows from index `skip_first` on are used; when `n` is smaller than that pool, n are
    drawn after shuffling with np.random.default_rng(seed). Each row's four options are
    permuted with random.Random(seed + row_index), so the layout is reproducible per
    row. The prompt is a boxed-letter instruction, the question, then options "A. ..".
    """
    from datasets import load_dataset
    import random

    ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split=split)
    total = len(ds)
    indices = list(range(min(skip_first, total), total))
    if n is not None and n < len(indices):
        np.random.default_rng(seed).shuffle(indices)
    indices = indices[:n]

    letters = ["A", "B", "C", "D"]
    box_hint = (
        "You are answering a 4-option multiple-choice question.\n"
        "Options are labeled A, B, C, and D.\n"
        "Think step-by-step and show your reasoning.\n"
        "At the very end, output ONE line exactly in this format:\n"
        "Final Answer: \\boxed{A}\n"
        "where the letter is A, B, C, or D.\n"
    )
    option_columns = ("Correct Answer", "Incorrect Answer 1", "Incorrect Answer 2", "Incorrect Answer 3")

    def _build(row_idx):
        row = ds[row_idx]
        question = str(row["Question"]).strip()
        options = [str(row[col]).strip() for col in option_columns]
        order = [0, 1, 2, 3]
        random.Random(seed + row_idx).shuffle(order)
        shuffled = [options[o] for o in order]
        gold = letters[order.index(0)]  # slot now holding the correct answer
        listing = "\n".join(f"{label}. {text}" for label, text in zip(letters, shuffled))
        return f"{box_hint}\n{question}\n\n{listing}\n", gold

    return [_build(i) for i in indices]

# ---------------- Multi-vec selection helpers ----------------
def _load_soft_json(path: str) -> Tuple[List[Tuple[int,int]], List[float]]:
    """Read {"edges": [[i, j], ...], "weights": [...]} from a JSON file.

    Missing keys default to empty lists. When the weight count differs from the edge
    count, uniform weights 1/len(edges) are substituted.
    """
    with open(path, "r") as fh:
        payload = json.load(fh)
    edges = [(int(a), int(b)) for a, b in payload.get("edges", [])]
    weights = [float(x) for x in payload.get("weights", [])]
    if len(weights) == len(edges):
        return edges, weights
    uniform = [1.0 / len(edges)] * len(edges) if edges else []
    return edges, uniform

def _find_vec_key_for_edge(all_keys: List[str], prefix: str, i: int, j: int) -> Optional[str]:
    """Find the stats-NPZ key holding the steering vector for edge (i, j).

    Tries the known spellings first ("vec::{prefix}:{i},{j}", the same with
    "edge_delta" -> "edge", "vec::{prefix}:{i}_{j}", "vec::{prefix}::{i},{j}"). Otherwise
    returns the first "vec::" key containing "i,j" or "i_j" as a whole-number pair —
    digit boundaries stop edge (1, 2) from matching e.g. "11,23". Returns None if none match.
    """
    key_set = set(all_keys)
    exact = f"vec::{prefix}:{i},{j}"
    for cand in (
        exact,
        exact.replace("edge_delta", "edge"),
        f"vec::{prefix}:{i}_{j}",
        f"vec::{prefix}::{i},{j}",
    ):
        if cand in key_set:
            return cand

    pair = re.compile(rf"(?<!\d){re.escape(str(i))}[,_]{re.escape(str(j))}(?!\d)")
    for k in all_keys:
        if k.startswith("vec::") and pair.search(k):
            return k
    return None

def build_vec_bank_from_soft(stats_npz_path: str, model_npz_path: str,
                             soft_json_path: str, prefix: str) -> Tuple[List[np.ndarray], List[float]]:
    """Build steering vectors (in model hidden space) and their selection probabilities.

    For every (edge, weight) from the soft JSON, the edge's vector is looked up in the
    stats NPZ via `_find_vec_key_for_edge` and mapped back through the preprocessing
    stored in `model_npz_path` with `invert_preproc_step`. Edges without a vector are
    dropped together with their weight; the kept weights are renormalised to sum to 1
    (uniform if their sum is <= 0).

    Returns:
        (vecs, probs) aligned by index; ([], []) when no edge resolves.
    """
    edges, weights = _load_soft_json(soft_json_path)
    scaler, pca = load_preproc(model_npz_path)

    vecs, kept = [], []
    # Context manager closes the NPZ's file handle; z[k] yields an in-memory array.
    with np.load(stats_npz_path, allow_pickle=True) as z:
        keys = list(z.files)
        for (i, j), w in zip(edges, weights):
            k = _find_vec_key_for_edge(keys, prefix, i, j)
            if k is None:
                continue
            vecs.append(invert_preproc_step(z[k], scaler, pca))
            kept.append(w)  # keep the weight aligned with its vector

    if not vecs:
        return [], []
    total = sum(kept)
    probs = [w / total for w in kept] if total > 0 else [1.0 / len(vecs)] * len(vecs)
    return vecs, probs

# ---------------- Evaluation (BATCHED) ----------------
def evaluate_llm_accuracy_batched(
        runner: LLMRunner,
        items: List[Tuple[str,str]],
        metric: str = "em",
        schedule: Optional[Dict[int,float]] = None,
        vec_hidden: Optional[np.ndarray] = None,
        regex_answer: Optional[str] = None,
        regex_pred: Optional[str] = None,
        show_progress: bool = False,
        base_seed: Optional[int] = None,
        steered: bool = False,
        batch_size: int = 1,
        step_aware: bool = True,
        vec_bank: Optional[List[np.ndarray]] = None,
        vec_probs: Optional[List[float]] = None,
        selection_mode: str = "single",  # "single" | "prob" | "argmax" | "none"
    ):
    """Batched evaluation with optional per-batch steering-vector selection.

    Vector choice per batch:
      - selection_mode == "none": no steering (schedule and vectors ignored)
      - vec_hidden given: always that vector
      - else with vec_bank: "single" -> vec_bank[0]; "prob" -> index sampled from
        vec_probs with an RNG seeded by base_seed (0 if None); "argmax" -> argmax(vec_probs)
    Steering is applied only when a vector was chosen and `schedule` is non-empty.
    `steered` is accepted for interface compatibility and not used.

    Returns:
        (accuracy, rows, total_gen_tokens, avg_gen_tokens); rows hold per-item
        prompt / gold / pred / ok / gen_tokens.

    Raises:
        ValueError: if batch_size < 1 (a negative step would silently evaluate nothing).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if selection_mode not in ("single", "prob", "argmax", "none") and _only_rank0():
        warnings.warn(f"[eval] Unknown selection_mode={selection_mode!r}; vec_bank will not be used.")

    # Validate vector width early if provided
    if vec_hidden is not None:
        _validate_vec_dim(vec_hidden, runner.mdl, where="evaluate_llm_accuracy_batched(vec_hidden)")
    if vec_bank:
        for k, v in enumerate(vec_bank):
            _validate_vec_dim(v, runner.mdl, where=f"evaluate_llm_accuracy_batched(vec_bank[{k}])")

    correct, rows = 0, []
    total_gen_tokens = 0

    it = range(0, len(items), batch_size)
    if show_progress and _only_rank0():
        it = tqdm(it, total=(len(items)+batch_size-1)//batch_size, desc="Evaluating", unit="batch")

    # RNG for "prob" selection; deterministic even without base_seed.
    rng = np.random.default_rng(base_seed if base_seed is not None else 0)

    for start in it:
        batch = items[start:start+batch_size]
        prompts = [p for (p, _) in batch]
        golds   = [g for (_, g) in batch]

        # Fresh, identically-seeded generator per batch (when seeding is requested).
        gen = None
        if base_seed is not None:
            gen = make_gen(_model_device(runner.mdl), int(base_seed))

        chosen_vec = None
        if selection_mode == "none":
            pass
        elif vec_hidden is not None:
            chosen_vec = vec_hidden
        elif vec_bank:
            if selection_mode == "single":
                chosen_vec = vec_bank[0]
            elif selection_mode == "prob":
                idx = rng.choice(len(vec_bank), p=(vec_probs if vec_probs else None))
                chosen_vec = vec_bank[idx]
            elif selection_mode == "argmax":
                idx = int(np.argmax(vec_probs)) if vec_probs else 0
                chosen_vec = vec_bank[idx]

        preds, gens = runner.generate_batched(
            prompts,
            schedule=(schedule if chosen_vec is not None else None),
            vec_hidden=chosen_vec,
            step_aware=step_aware,
            torch_generator=gen
        )

        for j, (pred, gold, gen_tokens) in enumerate(zip(preds, golds, gens)):
            ok = _grade(pred, gold, metric=metric, regex_answer=regex_answer, regex_pred=regex_pred)
            correct += int(ok)
            total_gen_tokens += int(gen_tokens)
            rows.append({
                "i": start + j,
                "prompt": prompts[j],
                "gold": gold,
                "pred": pred,
                "ok": bool(ok),
                "gen_tokens": int(gen_tokens),
            })

    acc = correct / max(1, len(items))
    avg_gen_tokens = (total_gen_tokens / max(1, len(items)))
    return acc, rows, total_gen_tokens, avg_gen_tokens
